{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ee939d6d",
   "metadata": {},
   "source": [
    "# Unit Tests Generator for Code\n",
    "\n",
    "Tool for generating unit tests for code using a locally served Llama model (via Ollama)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d61ff2a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports \n",
    "\n",
    "from openai import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1410b7dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenAI-compatible client pointed at a local Ollama server; the api_key\n",
    "# value is a placeholder required by the SDK, not a real credential\n",
    "openai = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8391d095",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the locally served Ollama model used for test generation\n",
    "\n",
    "MODEL = \"llama3.2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f55ad72",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_user_prompt(code_snippet=\"\"\"\n",
    "def calculate_total_price(price, tax_rate):\n",
    "  return price * (1 + tax_rate)\n",
    "\"\"\"):\n",
    "  return f\"\"\"\n",
    "Please generate unit tests for the following code. Maximize on coverage. Take care of edge cases as well.\n",
    "\n",
    "```python\n",
    "{code_snippet}\n",
    "```\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b0e6e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the prompt with the default sample snippet (calculate_total_price)\n",
    "user_prompt = create_user_prompt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "648e61f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the assembled prompt before sending it to the model\n",
    "print(user_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af787e3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_unit_tests_for_code(user_prompt):\n",
    "  \"\"\"Send ``user_prompt`` to the local model and return the generated tests.\n",
    "\n",
    "  Uses the module-level ``openai`` client and ``MODEL`` name; the text of\n",
    "  the first completion choice is returned as-is.\n",
    "  \"\"\"\n",
    "  messages = [\n",
    "    {\"role\": \"system\",\n",
    "     \"content\": \"You are a helpful assistant that generates unit tests for code.\"},\n",
    "    {\"role\": \"user\", \"content\": user_prompt},\n",
    "  ]\n",
    "  completion = openai.chat.completions.create(model=MODEL, messages=messages)\n",
    "  return completion.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e740c9e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the local model to generate the tests (requires the Ollama server\n",
    "# configured above to be running)\n",
    "result = create_unit_tests_for_code(user_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9b030c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the generated unit tests\n",
    "print(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
